Skip to content

OpenCL: add conv2d kernel #14403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

Conversation

rmatif
Copy link
Collaborator

@rmatif rmatif commented Jun 26, 2025

Following up on #14316 and #14388, this PR adds a direct conv2d kernel for OpenCL. To maximize performance, this kernel uses a mixed-precision approach: data is stored in local memory as FP16 to save bandwidth and the core operations are vectorized using float4 for higher throughput.
Because of this, a comparison with an indirect conv2d implementation is not based on identical precision and it's not a fair comparison. I thought that since this is mainly designed for Adreno GPUs, we could sacrifice some accuracy for the benefit of maximum performance, which is a significant bottleneck on these devices. As a result, some tests fail by a small margin due to the precision differences, hope it's still okay!

I am opening this PR to gather feedback and to see if this performance/accuracy trade-off is acceptable or not

Performance:

GPU Direct (GFLOPS) Indirect (GFLOPS) Speedup
Adreno 830 520.74 38.02 13.70x
Adreno 750 385.77 27.28 14.14x
Adreno 740 211.38 25.12 8.42x
Adreno 730 158.83 19.34 8.21x

@lhez @max-krasnyansky

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Jun 26, 2025
@etasnadi
Copy link
Contributor

It seems that your kernel is the opencl vectorized version of the vulkan kernel I proposed, but I do not see this kind of performance improvement on Vulkan over the indirect impl. You might want to disable vectorized access to see what causes the improvement.

@rmatif
Copy link
Collaborator Author

rmatif commented Jun 27, 2025

It seems that your kernel is the opencl vectorized version of the vulkan kernel I proposed, but I do not see this kind of performance improvement on Vulkan over the indirect impl. You might want to disable vectorized access to see what causes the improvement.

I have taken inspiration from your CUDA implementation, thanks for it!, so it's pretty much similar approach

After disabling vectorization, the scalar kernel achieves 182.18 GFLOPS on the Adreno 830. I think the significant speedup over the indirect implementation is mainly due to the current OpenCL backend being unoptimized, rather than any specific feature of the new kernel. The im2col kernel has poor memory access patterns and performs worse than the CPU implementation, while the subsequent mul_mat operation falls back to a generic f16 kernel which is not well-optimized, as only the q4_0 kernels show good looking performance as of now

@etasnadi
Copy link
Contributor

It seems that your kernel is the opencl vectorized version of the vulkan kernel I proposed, but I do not see this kind of performance improvement on Vulkan over the indirect impl. You might want to disable vectorized access to see what causes the improvement.

I have taken inspiration from your CUDA implementation, thanks for it!, so it's pretty much similar approach

After disabling vectorization, the scalar kernel achieves 182.18 GFLOPS on the Adreno 830. I think the significant speedup over the indirect implementation is mainly due to the current OpenCL backend being unoptimized, rather than any specific feature of the new kernel. The im2col kernel has poor memory access patterns and performs worse than the CPU implementation, while the subsequent mul_mat operation falls back to a generic f16 kernel which is not well-optimized, as only the q4_0 kernels show good looking performance as of now

It's good to know -- that could be the reason. I also observed this in Vulkan: the direct kernel is faster because the mul_mat kernel is not optimized well enough (at least not to my device) while the direct kernel is more of less more optimized to my device.

I also ported the direct kernel to CUDA and found that the indirect im2col&cuBLAS based mul_mat is ~33% faster than my direct kernel on Turing (the cuBLAS matmul is very highly optimized). I found this promising because there are lots of opportunities for optimization in the direct kernel (eliminating bank conflicts, warp-tiling, double buffering, faster computation of the offsets), so the direct kernel could become on par with the highly optimized indirect kernel in performance while not wasting lots of memory as im2col does.

@CISC CISC added the OpenCL Issues specific to the OpenCL backend label Jul 2, 2025
@ggerganov ggerganov requested a review from max-krasnyansky July 4, 2025 17:58
@rmatif
Copy link
Collaborator Author

rmatif commented Jul 9, 2025

@lhez Should I add test cases to test-backend-ops to make testing easier for you? I initially thought #14316 would be merged first, so I didn’t include them

@lhez
Copy link
Collaborator

lhez commented Jul 11, 2025

@lhez Should I add test cases to test-backend-ops to make testing easier for you? I initially thought #14316 would be merged first, so I didn’t include them

It should be good to keep it this way. I can apply this PR on top of #14316 in my local env.

@rmatif
Copy link
Collaborator Author

rmatif commented Jul 16, 2025

It was probably not a good idea to force input conversion to F16, as we are hitting memory bandwidth limitations quickly anyway. On stable-diffusion.cpp, there was no significant improvement during sampling when running under F16 or F32.

To fix this, I have added a separate kernel for F16 inputs, as well as a kernel that handles an F16 conv2d kernel with an F32 input (which stable-diffusion.cpp uses). Now all tests pass and sd.cpp works correctly. This kernel, combined with the tiled mul_mat, has led to a ~2.1x improvement in inference speed so far

Here are the results for now on Adreno 830:

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  3 runs - 369258.33 us/run - 137.42 GFLOP/run - 372.16 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2992 runs -   385.54 us/run - 133.69 MFLOP/run - 346.77 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2211 runs -   540.53 us/run - 135.78 MFLOP/run - 251.20 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             24576 runs -    43.13 us/run - 642.82 kFLOP/run -  14.91 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4786 runs -   519.87 us/run -  20.90 MFLOP/run -  40.20 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     8192 runs -   261.00 us/run -   2.78 MFLOP/run -  10.67 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4489 runs -  2000.09 us/run -  22.28 MFLOP/run -  11.14 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    2601 runs -   450.84 us/run - 115.40 MFLOP/run - 255.98 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     436 runs -  2466.03 us/run - 923.24 MFLOP/run - 374.38 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  220 runs -  5290.88 us/run -   1.85 GFLOP/run - 349.45 GFLOPS

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                        1 runs - 3611056.00 us/run - 137.42 GFLOP/run -  38.06 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   748 runs -  4435.67 us/run - 133.69 MFLOP/run -  30.14 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   737 runs -  4511.76 us/run - 135.78 MFLOP/run -  30.10 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5120 runs -   219.79 us/run - 642.82 kFLOP/run -   2.92 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  3116.75 us/run -  20.90 MFLOP/run -   6.70 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs -  2700.94 us/run -   2.78 MFLOP/run -   1.03 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   1024 runs - 21760.91 us/run -  22.28 MFLOP/run -   1.02 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   867 runs -  2116.94 us/run - 115.40 MFLOP/run -  54.51 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   109 runs - 16969.63 us/run - 923.24 MFLOP/run -  54.41 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 55 runs - 50055.93 us/run -   1.85 GFLOP/run -  36.94 GFLOPS

CPU for comparaison:

  CONV_2D_DIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                  1 runs - 1144127.00 us/run - 137.42 GFLOP/run - 120.11 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     660 runs -  1605.56 us/run - 133.69 MFLOP/run -  83.27 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     649 runs -  1650.94 us/run - 135.78 MFLOP/run -  82.25 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):             16384 runs -    63.26 us/run - 642.82 kFLOP/run -  10.16 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1149 runs -  1051.21 us/run -  20.90 MFLOP/run -  19.88 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2873 runs -   491.97 us/run -   2.78 MFLOP/run -   5.66 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      360 runs -  4436.29 us/run -  22.28 MFLOP/run -   5.02 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     770 runs -  1328.21 us/run - 115.40 MFLOP/run -  86.89 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      99 runs - 10822.51 us/run - 923.24 MFLOP/run -  85.31 GFLOPS
  CONV_2D_DIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   60 runs - 17910.78 us/run -   1.85 GFLOP/run - 103.23 GFLOPS

  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                        1 runs - 1166093.00 us/run - 137.42 GFLOP/run - 117.85 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   660 runs -  1518.26 us/run - 133.69 MFLOP/run -  88.06 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   649 runs -  1543.69 us/run - 135.78 MFLOP/run -  87.96 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    5120 runs -   220.06 us/run - 642.82 kFLOP/run -   2.92 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    766 runs -  1631.11 us/run -  20.90 MFLOP/run -  12.81 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   2048 runs -   939.86 us/run -   2.78 MFLOP/run -   2.96 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    360 runs -  9065.79 us/run -  22.28 MFLOP/run -   2.46 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   840 runs -  1302.27 us/run - 115.40 MFLOP/run -  88.62 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    90 runs - 11948.18 us/run - 923.24 MFLOP/run -  77.27 GFLOPS
  CONV_2D_INDIRECT_IMPL(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 55 runs - 19728.93 us/run -   1.85 GFLOP/run -  93.71 GFLOPS

@rmatif rmatif requested a review from lhez July 16, 2025 18:27
@rmatif rmatif force-pushed the opencl-add-conv2d branch from 8110881 to 8412441 Compare July 17, 2025 11:00
@lhez
Copy link
Collaborator

lhez commented Jul 18, 2025

@rmatif thank you and sorry for the delay: have been distracted. Do you think it would be better for me to also get this running with stable-difusion.cpp?

@etasnadi
Copy link
Contributor

@rmatif If subgroup ops are also available on Adreno, I suggest to include my update from last week. It significantly reduces the number of integer division and modulo ops by exchanging data between threads as these ops can be very slow on many devices.

https://github.com/ggml-org/llama.cpp/blob/f8295bfc76bb3774083cc55dd3a529a574fc68a3/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

@rmatif
Copy link
Collaborator Author

rmatif commented Jul 18, 2025

@rmatif thank you and sorry for the delay: have been distracted. Do you think it would be better for me to also get this running with stable-difusion.cpp?

I think it's generally good practice to test the kernel on a real use-case scenario. For sdcpp, you just need to replace all three occurrences of ggml_conv_2d with ggml_conv_2d_direct in ggml_extend.hpp, as well as the one in preprocessing.hpp, and then build using -DSD_OPENCL=ON

Just let me know if you need any help or command-line example

@rmatif If subgroup ops are also available on Adreno, I suggest to include my update from last week. It significantly reduces the number of integer division and modulo ops by exchanging data between threads as these ops can be very slow on many devices.

https://github.com/ggml-org/llama.cpp/blob/f8295bfc76bb3774083cc55dd3a529a574fc68a3/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

I generally try to avoid using subgroups, since the goal is to make this backend compatible with OpenCL 1.2. That said, I’ll take a look and see if I can add an alternative path using subgroups. Thanks for the suggestion

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning OpenCL Issues specific to the OpenCL backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants